import os.path as osp

import pandas as pd
import numpy as np
from rpy2 import robjects
from rpy2.robjects.vectors import StrVector
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter
from sklearn.base import BaseEstimator

from aif360.sklearn.utils import check_inputs, check_groups


class FairAdapt(BaseEstimator):
    """Fair Data Adaptation.

    Fairadapt is a pre-processing technique that can be used for both fair
    classification and fair regression [#plecko20]_. The method is a causal
    inference approach to bias removal and it relies on the causal graph for
    the dataset. The original implementation is in R [#plecko21]_.

    References:
        .. [#plecko20] `D. Plečko and N. Meinshausen,
           "Fair Data Adaptation with Quantile Preservation,"
           Journal of Machine Learning Research, 2020.
           <https://www.jmlr.org/papers/volume21/19-966/19-966.pdf>`_
        .. [#plecko21] `D. Plečko and N. Bennett and N. Meinshausen,
           "FairAdapt: Causal Reasoning for Fair Data Pre-processing,"
           arXiv, 2021. <https://arxiv.org/abs/2110.10200>`_

    Attributes:
        prot_attr_ (str or list(str)): Protected attribute(s) used for fair data
            adaptation.
        groups_ (array, shape (n_groups,)): A list of group labels known to the
            transformer.
    """

    def __init__(self, prot_attr, adj_mat):
        """
        Args:
            prot_attr (single label): Name of the protected attribute. Must be
                binary.
            adj_mat (array-like): A 2-dimensional array representing the
                adjacency matrix of the causal diagram of the data generating
                process. Row/column order must match `X_train`.
        """
        self.prot_attr = prot_attr
        self.adj_mat = adj_mat

        # Constructing the estimator makes sure the R side is usable; this may
        # download and install R packages from CRAN.
        self._install_missing_r_packages()

    @staticmethod
    def _install_missing_r_packages(pkgs=('ranger', 'fairadapt')):
        """Install the R packages needed to run FairAdapt if they are missing.

        Args:
            pkgs (iterable of str): Names of the required R packages.
        """
        # selectively install the missing packages
        missing = [p for p in pkgs if not robjects.packages.isinstalled(p)]
        if missing:
            utils = importr('utils')
            utils.chooseCRANmirror(ind=1)
            utils.install_packages(StrVector(missing))

    def fit_transform(self, X_train, y_train, X_test):
        """Remove bias from the given dataset by fair adaptation.

        Args:
            X_train (pandas.DataFrame): Training data frame (including the
                protected attribute).
            y_train (pandas.Series): Training labels. The series must be named;
                its name is passed to R as the outcome column.
            X_test (pandas.DataFrame): Test data frame (including the protected
                attribute).

        Returns:
            tuple:
                Transformed inputs.

                * **X_fair_train** (pandas.DataFrame) -- Transformed training
                  data.
                * **y_fair_train** (array-like) -- Transformed training labels.
                * **X_fair_test** (pandas.DataFrame) -- Transformed test data.

        Raises:
            ValueError: If `y_train` has no name.
        """
        # merge X_train and y_train
        df_train = pd.concat([X_train, y_train], axis=1)

        # The outcome column is identified by name on the R side, so an unnamed
        # series cannot work. Fail here with a clear message instead of passing
        # None through to R.
        outcome = y_train.name
        if outcome is None:
            raise ValueError(
                "y_train must be a named pandas.Series; its name is used as "
                "the outcome column for fairadapt.")

        groups, self.prot_attr_ = check_groups(X_train, self.prot_attr,
                                               ensure_binary=True)
        self.groups_ = np.unique(groups)

        # The R entry point lives in fairadapt.R next to this module and is
        # exposed there as a function called `wrapper`.
        wrapper = osp.join(osp.dirname(osp.abspath(__file__)), 'fairadapt.R')
        robjects.r.source(wrapper)
        FairAdapt_R = robjects.r['wrapper']

        # convert the Python inputs to R objects with a local converter
        with localconverter(robjects.default_converter + pandas2ri.converter):
            train_data = robjects.conversion.py2rpy(df_train)
            test_data = robjects.conversion.py2rpy(X_test)
            adj_mat = robjects.conversion.py2rpy(self.adj_mat)

        # run FairAdapt in R
        res = FairAdapt_R(
            train_data=train_data,
            test_data=test_data,
            adj_mat=adj_mat,
            prot_attr=self.prot_attr_,
            outcome=outcome
        )

        # convert the adapted R data frames back to pandas
        with localconverter(robjects.default_converter + pandas2ri.converter):
            X_fair_train = robjects.conversion.rpy2py(res.rx2('train'))
            X_fair_test = robjects.conversion.rpy2py(res.rx2('test'))

        # Column labels are restored positionally. NOTE(review): this assumes
        # the adapted training frame has the outcome as its first column,
        # followed by the features in X_train order -- confirm in fairadapt.R.
        X_fair_train.columns = [outcome] + X_train.columns.tolist()
        y_fair_train = X_fair_train.pop(outcome)
        X_fair_test.columns = X_test.columns

        return X_fair_train, y_fair_train, X_fair_test
